{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model.utils.data_generator import DataGenerator\n",
    "from model.img2seq import Img2SeqModel\n",
    "from model.utils.lr_schedule import LRSchedule\n",
    "from model.utils.general import Config\n",
    "from model.utils.text import Vocab\n",
    "from model.utils.image import greyscale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory for the results and the model weights; the JSON config files\n",
    "# below are located under the same directory.\n",
    "output = \"results/full4/\"\n",
    "\n",
    "# Paths to the data / vocab / training / model JSON configs.\n",
    "data, vocab, training, model = (\n",
    "    output + \"data.json\",\n",
    "    output + \"vocab.json\",\n",
    "    output + \"training.json\",\n",
    "    output + \"model.json\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load configs.\n",
    "# NOTE: `vocab` and `model` hold the JSON config paths set in the cell above.\n",
    "# The Vocab and Img2SeqModel objects get their own names below instead of\n",
    "# overwriting those paths, so this cell can be re-run without first\n",
    "# re-running the path cell.\n",
    "dir_output = output\n",
    "config = Config([data, vocab, training, model])\n",
    "# config.save(dir_output)  # currently disabled\n",
    "formula_vocab = Vocab(config)\n",
    "\n",
    "# Load datasets. Both splits share the options below; only the paths and\n",
    "# the bucket setting differ between train and validation.\n",
    "shared_options = dict(img_prepro=greyscale,\n",
    "                      max_iter=config.max_iter,\n",
    "                      max_len=config.max_length_formula,\n",
    "                      form_prepro=formula_vocab.form_prepro)\n",
    "train_set = DataGenerator(path_formulas=config.path_formulas_train,\n",
    "                          dir_images=config.dir_images_train,\n",
    "                          bucket=config.bucket_train,\n",
    "                          path_matching=config.path_matching_train,\n",
    "                          **shared_options)\n",
    "val_set = DataGenerator(path_formulas=config.path_formulas_val,\n",
    "                        dir_images=config.dir_images_val,\n",
    "                        bucket=config.bucket_val,\n",
    "                        path_matching=config.path_matching_val,\n",
    "                        **shared_options)\n",
    "\n",
    "# Define learning rate schedule.\n",
    "# n_batches_epoch is the ceiling of len(train_set) / batch_size; the\n",
    "# start_decay, end_decay and end_warm settings are multiplied by it.\n",
    "n_batches_epoch = ((len(train_set) + config.batch_size - 1) //\n",
    "                   config.batch_size)\n",
    "lr_schedule = LRSchedule(lr_init=config.lr_init,\n",
    "                         start_decay=config.start_decay*n_batches_epoch,\n",
    "                         end_decay=config.end_decay*n_batches_epoch,\n",
    "                         end_warm=config.end_warm*n_batches_epoch,\n",
    "                         lr_warm=config.lr_warm,\n",
    "                         lr_min=config.lr_min)\n",
    "\n",
    "# Build model and train\n",
    "img2seq = Img2SeqModel(config, dir_output, formula_vocab)\n",
    "img2seq.build_train(config)\n",
    "img2seq.train(config, train_set, val_set, lr_schedule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
